{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple SFT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "# unsloth 有个bug，会不认 Huggingface 的 model cache, 所以需要手动下载\n",
    "mkdir -p output/models/\n",
    "# 国内需要配置 Huggingface mirror\n",
    "export HF_ENDPOINT=https://hf-mirror.com\n",
    "huggingface-cli download --resume-download Qwen/Qwen2.5-1.5B-Instruct --local-dir output/models/Qwen2.5-1.5B-Instruct\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# 国内需要禁止统计，否则会卡在模型加载的地方（连不到外网）\n",
    "os.environ[\"UNSLOTH_DISABLE_STATISTICS\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
      "Unsloth: Failed to patch Gemma3ForConditionalGeneration.\n",
      "🦥 Unsloth Zoo will now patch everything to make training faster!\n",
      "INFO 04-22 13:55:35 [__init__.py:239] Automatically detected platform cuda.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==((====))==  Unsloth 2025.3.19: Fast Qwen2 patching. Transformers: 4.51.3. vLLM: 0.8.4.\n",
      "   \\\\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 22.159 GB. Platform: Linux.\n",
      "O^O/ \\_/ \\    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0\n",
      "\\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]\n",
      " \"-____-\"     Free license: http://github.com/unslothai/unsloth\n",
      "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Unsloth 2025.3.19 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.\n"
     ]
    }
   ],
   "source": [
    "from unsloth import FastLanguageModel, is_bfloat16_supported\n",
    "import torch\n",
    "\n",
    "max_seq_length = 1024 # Can increase for longer reasoning traces\n",
    "lora_rank = 64 # Larger rank = smarter, but slower\n",
    "\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name = \"/data/countdown/output/models/Qwen2.5-1.5B-Instruct\", # change to your model path\n",
    "    max_seq_length = max_seq_length,\n",
    "    load_in_4bit = False, # False for LoRA 16bit\n",
    "    local_files_only=True,\n",
    "    # fast_inference = True, # Enable vLLM fast inference\n",
    "    max_lora_rank = lora_rank,\n",
    "    # gpu_memory_utilization = 0.5, # Reduce if out of memory\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
    "    target_modules = [\n",
    "        \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
    "        \"gate_proj\", \"up_proj\", \"down_proj\",\n",
    "    ], # Remove QKVO if out of memory\n",
    "    lora_alpha = lora_rank * 2,\n",
    "    use_gradient_checkpointing = \"unsloth\", # Enable long context finetuning\n",
    "    random_state = 3407,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 可以使用Unsloth框架中定义的，也可以使用模型自带的 chat template\n",
    "# 新的模型建议使用 Unsloth 框架带的，会修复模型自带的 chat template 的Bug\n",
    "from unsloth.chat_templates import get_chat_template\n",
    "tokenizer = get_chat_template(\n",
    "   tokenizer,\n",
    "   chat_template = \"qwen-2.5\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "527fbec534ae4d15807a7e2946adc85b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split: 0 examples [00:00, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "dataset = load_dataset(\"json\", data_files=\"data/train_sft_simple.jsonl\")[\"train\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0203720c6954936acee56e0575acf6a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def apply_chat_template(examples):\n",
    "    texts = tokenizer.apply_chat_template(examples[\"messages\"])\n",
    "    texts = [tokenizer.decode(text) for text in texts]\n",
    "    return { \"text\" : texts }\n",
    "\n",
    "dataset = dataset.map(apply_chat_template, batched = True, load_from_cache_file=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|im_start|>system\n",
      "You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n",
      "<|im_start|>user\n",
      "Using the numbers 36, 70, 644, 97, create an equation that equals 653. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once, and you must use all the numbers. Show your work in <think> </think> tags. And return the final equation in <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>. Think step by step inside <think> tags.<|im_end|>\n",
      "<|im_start|>assistant\n",
      "<think>Step 1: 36 - 97 = -61\n",
      "Step 2: -61 + 644 = 583\n",
      "Step 3: 583 + 70 = 653\n",
      "Final answer: ((36 - 97) + 644) + 70</think>\n",
      "\n",
      "<answer>((36 - 97) + 644) + 70</answer><|im_end|>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(dataset[99][\"text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter:"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "  ········\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: No netrc file found, creating one.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mswulling\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.19.9"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/data/countdown/wandb/run-20250422_135610-gdewj8jy</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/swulling/countdown-sft-simple/runs/gdewj8jy' target=\"_blank\">sage-jazz-6</a></strong> to <a href='https://wandb.ai/swulling/countdown-sft-simple' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/swulling/countdown-sft-simple' target=\"_blank\">https://wandb.ai/swulling/countdown-sft-simple</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/swulling/countdown-sft-simple/runs/gdewj8jy' target=\"_blank\">https://wandb.ai/swulling/countdown-sft-simple/runs/gdewj8jy</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/swulling/countdown-sft-simple/runs/gdewj8jy?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
      ],
      "text/plain": [
       "<wandb.sdk.wandb_run.Run at 0x7f68096845b0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import wandb\n",
    "wandb.init(project=\"countdown-sft-simple\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "938198d1630545bca4e3f739f7ed035c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Unsloth: Tokenizing [\"text\"] (num_proc=12):   0%|          | 0/1000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from trl import SFTTrainer, SFTConfig\n",
    "trainer = SFTTrainer(\n",
    "    model = model,\n",
    "    tokenizer = tokenizer,\n",
    "    train_dataset = dataset,\n",
    "    eval_dataset = None, # Can set up evaluation!\n",
    "    args = SFTConfig(\n",
    "        dataset_text_field = \"text\",\n",
    "        per_device_train_batch_size = 8,\n",
    "        gradient_accumulation_steps = 1, # Use GA to mimic batch size!\n",
    "        warmup_ratio = 0.1,\n",
    "        num_train_epochs = 2,\n",
    "        learning_rate = 1e-4, # Reduce to 2e-5 for long training runs\n",
    "        logging_steps = 1,\n",
    "        optim = \"adamw_8bit\",\n",
    "        weight_decay = 0.01,\n",
    "        lr_scheduler_type = \"cosine\",\n",
    "        seed = 3407,\n",
    "        report_to = \"wandb\", # Use this for WandB etc\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|im_start|>system\n",
      "You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n",
      "<|im_start|>user\n",
      "Using the numbers 36, 70, 644, 97, create an equation that equals 653. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once, and you must use all the numbers. Show your work in <think> </think> tags. And return the final equation in <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>. Think step by step inside <think> tags.<|im_end|>\n",
      "<|im_start|>assistant\n",
      "<think>Step 1: 36 - 97 = -61\n",
      "Step 2: -61 + 644 = 583\n",
      "Step 3: 583 + 70 = 653\n",
      "Final answer: ((36 - 97) + 644) + 70</think>\n",
      "\n",
      "<answer>((36 - 97) + 644) + 70</answer><|im_end|>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(tokenizer.decode(trainer.train_dataset[99][\"input_ids\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# todo: train completions only"
   ]
  },
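  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the TODO above (computing the loss on completions only), assuming Unsloth's `train_on_responses_only` helper and the standard Qwen-2.5 turn markers apply to this setup; it is not run in this notebook:\n",
    "\n",
    "```python\n",
    "# Hedged sketch: mask the prompt so the loss is computed on assistant tokens only.\n",
    "from unsloth.chat_templates import train_on_responses_only\n",
    "\n",
    "trainer = train_on_responses_only(\n",
    "    trainer,\n",
    "    instruction_part=\"<|im_start|>user\\n\",    # tokens from here up to the response part are masked\n",
    "    response_part=\"<|im_start|>assistant\\n\",  # loss is computed on the tokens after this marker\n",
    ")\n",
    "```"
   ]
  },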
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
      "   \\\\   /|    Num examples = 1,000 | Num Epochs = 2 | Total steps = 250\n",
      "O^O/ \\_/ \\    Batch size per device = 8 | Gradient accumulation steps = 1\n",
      "\\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8\n",
      " \"-____-\"     Trainable parameters = 73,859,072/1,617,573,376 (4.57% trained)\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='250' max='250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [250/250 01:55, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.036100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>2.056900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>2.058000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.995800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.884200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.729500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>1.541300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>1.390100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>1.235100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>1.085800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.942000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.799700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.668700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.525400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.419900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.327500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.250500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.200200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.167900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.148800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.142300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.146100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.140300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.140700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.145500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.144700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.141700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.144100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.135200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.147400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.140300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.147700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.135100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.139200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.136000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.138400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>0.143800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.142200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>0.138000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.143500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>0.137600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>0.137700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>0.138300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>0.143900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>0.136700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>46</td>\n",
       "      <td>0.141500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>47</td>\n",
       "      <td>0.147400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>0.133900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49</td>\n",
       "      <td>0.141800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.135100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>51</td>\n",
       "      <td>0.137000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>52</td>\n",
       "      <td>0.131100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>53</td>\n",
       "      <td>0.134700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>54</td>\n",
       "      <td>0.135900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55</td>\n",
       "      <td>0.137800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>0.132400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57</td>\n",
       "      <td>0.138100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>58</td>\n",
       "      <td>0.130300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>59</td>\n",
       "      <td>0.132600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.140600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>61</td>\n",
       "      <td>0.132700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>62</td>\n",
       "      <td>0.140200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>63</td>\n",
       "      <td>0.131300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>64</td>\n",
       "      <td>0.137200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>65</td>\n",
       "      <td>0.131000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>66</td>\n",
       "      <td>0.134800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>67</td>\n",
       "      <td>0.128900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>68</td>\n",
       "      <td>0.131000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>69</td>\n",
       "      <td>0.128600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.127700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>71</td>\n",
       "      <td>0.130700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>72</td>\n",
       "      <td>0.130900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>73</td>\n",
       "      <td>0.129800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>74</td>\n",
       "      <td>0.130700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75</td>\n",
       "      <td>0.134000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>76</td>\n",
       "      <td>0.129700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>77</td>\n",
       "      <td>0.124600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>78</td>\n",
       "      <td>0.146300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>79</td>\n",
       "      <td>0.138500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>0.135500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>81</td>\n",
       "      <td>0.129900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>82</td>\n",
       "      <td>0.142700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>83</td>\n",
       "      <td>0.144000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>84</td>\n",
       "      <td>0.130500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>85</td>\n",
       "      <td>0.134600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>86</td>\n",
       "      <td>0.137200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>87</td>\n",
       "      <td>0.136300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>88</td>\n",
       "      <td>0.142000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>89</td>\n",
       "      <td>0.129900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>0.135600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>91</td>\n",
       "      <td>0.133000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>92</td>\n",
       "      <td>0.136300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>93</td>\n",
       "      <td>0.128600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>94</td>\n",
       "      <td>0.140800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>95</td>\n",
       "      <td>0.130700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>96</td>\n",
       "      <td>0.131900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>97</td>\n",
       "      <td>0.132400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>98</td>\n",
       "      <td>0.138400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>99</td>\n",
       "      <td>0.128900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.135100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>101</td>\n",
       "      <td>0.132700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>102</td>\n",
       "      <td>0.138900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>103</td>\n",
       "      <td>0.129900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>104</td>\n",
       "      <td>0.128200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>105</td>\n",
       "      <td>0.130400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>106</td>\n",
       "      <td>0.135300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>107</td>\n",
       "      <td>0.127400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>108</td>\n",
       "      <td>0.136300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>109</td>\n",
       "      <td>0.135800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>0.137100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>111</td>\n",
       "      <td>0.138400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>112</td>\n",
       "      <td>0.135400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>113</td>\n",
       "      <td>0.136800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>114</td>\n",
       "      <td>0.133100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>115</td>\n",
       "      <td>0.132200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>116</td>\n",
       "      <td>0.136500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>117</td>\n",
       "      <td>0.128100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>118</td>\n",
       "      <td>0.134900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>119</td>\n",
       "      <td>0.135200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>0.125800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>121</td>\n",
       "      <td>0.124800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>122</td>\n",
       "      <td>0.137500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>123</td>\n",
       "      <td>0.134900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>124</td>\n",
       "      <td>0.128900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>125</td>\n",
       "      <td>0.131000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>126</td>\n",
       "      <td>0.125000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>127</td>\n",
       "      <td>0.128600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>128</td>\n",
       "      <td>0.127600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>129</td>\n",
       "      <td>0.134100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>130</td>\n",
       "      <td>0.128700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>131</td>\n",
       "      <td>0.131300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>132</td>\n",
       "      <td>0.130300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>133</td>\n",
       "      <td>0.131800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>134</td>\n",
       "      <td>0.130000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>135</td>\n",
       "      <td>0.134200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>136</td>\n",
       "      <td>0.132500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>137</td>\n",
       "      <td>0.132600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>138</td>\n",
       "      <td>0.126200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>139</td>\n",
       "      <td>0.133300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>0.135700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>141</td>\n",
       "      <td>0.127100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>142</td>\n",
       "      <td>0.124300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>143</td>\n",
       "      <td>0.130100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>144</td>\n",
       "      <td>0.122400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>145</td>\n",
       "      <td>0.126500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>146</td>\n",
       "      <td>0.133400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>147</td>\n",
       "      <td>0.139500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>148</td>\n",
       "      <td>0.131900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>149</td>\n",
       "      <td>0.138000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>0.124800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>151</td>\n",
       "      <td>0.136400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>152</td>\n",
       "      <td>0.130700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>153</td>\n",
       "      <td>0.125700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>154</td>\n",
       "      <td>0.132300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>155</td>\n",
       "      <td>0.137500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>156</td>\n",
       "      <td>0.133300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>157</td>\n",
       "      <td>0.128500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>158</td>\n",
       "      <td>0.128900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>159</td>\n",
       "      <td>0.126500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>0.127100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>161</td>\n",
       "      <td>0.129200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>162</td>\n",
       "      <td>0.132800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>163</td>\n",
       "      <td>0.130900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>164</td>\n",
       "      <td>0.128300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>165</td>\n",
       "      <td>0.131700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>166</td>\n",
       "      <td>0.134300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>167</td>\n",
       "      <td>0.130600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>168</td>\n",
       "      <td>0.127000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>169</td>\n",
       "      <td>0.128000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>170</td>\n",
       "      <td>0.134300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>171</td>\n",
       "      <td>0.119300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>172</td>\n",
       "      <td>0.133500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>173</td>\n",
       "      <td>0.128500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>174</td>\n",
       "      <td>0.132300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>175</td>\n",
       "      <td>0.130600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>176</td>\n",
       "      <td>0.125700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>177</td>\n",
       "      <td>0.126100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>178</td>\n",
       "      <td>0.136900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>179</td>\n",
       "      <td>0.130600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>0.132800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>181</td>\n",
       "      <td>0.126900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>182</td>\n",
       "      <td>0.126400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>183</td>\n",
       "      <td>0.128800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>184</td>\n",
       "      <td>0.125200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>185</td>\n",
       "      <td>0.128200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>186</td>\n",
       "      <td>0.135500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>187</td>\n",
       "      <td>0.125500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>188</td>\n",
       "      <td>0.133900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>189</td>\n",
       "      <td>0.129800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>190</td>\n",
       "      <td>0.130300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>191</td>\n",
       "      <td>0.136400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>192</td>\n",
       "      <td>0.121900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>193</td>\n",
       "      <td>0.131300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>194</td>\n",
       "      <td>0.132500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>195</td>\n",
       "      <td>0.135000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>196</td>\n",
       "      <td>0.130700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>197</td>\n",
       "      <td>0.127500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>198</td>\n",
       "      <td>0.128200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>199</td>\n",
       "      <td>0.126100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.125100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>201</td>\n",
       "      <td>0.130900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>202</td>\n",
       "      <td>0.131200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>203</td>\n",
       "      <td>0.130500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>204</td>\n",
       "      <td>0.124100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>205</td>\n",
       "      <td>0.129800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>206</td>\n",
       "      <td>0.121900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>207</td>\n",
       "      <td>0.132500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>208</td>\n",
       "      <td>0.125200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>209</td>\n",
       "      <td>0.124400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>210</td>\n",
       "      <td>0.133100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>211</td>\n",
       "      <td>0.130700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>212</td>\n",
       "      <td>0.130800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>213</td>\n",
       "      <td>0.127300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>214</td>\n",
       "      <td>0.126200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>215</td>\n",
       "      <td>0.127800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>216</td>\n",
       "      <td>0.132600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>217</td>\n",
       "      <td>0.128100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>218</td>\n",
       "      <td>0.128200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>219</td>\n",
       "      <td>0.128700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>220</td>\n",
       "      <td>0.131400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>221</td>\n",
       "      <td>0.129300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>222</td>\n",
       "      <td>0.125000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>223</td>\n",
       "      <td>0.137100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>224</td>\n",
       "      <td>0.130200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>225</td>\n",
       "      <td>0.135000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>226</td>\n",
       "      <td>0.125800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>227</td>\n",
       "      <td>0.124200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>228</td>\n",
       "      <td>0.128800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>229</td>\n",
       "      <td>0.131600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>230</td>\n",
       "      <td>0.129800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>231</td>\n",
       "      <td>0.132300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>232</td>\n",
       "      <td>0.125400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>233</td>\n",
       "      <td>0.132100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>234</td>\n",
       "      <td>0.128800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>235</td>\n",
       "      <td>0.133300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>236</td>\n",
       "      <td>0.119700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>237</td>\n",
       "      <td>0.133400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>238</td>\n",
       "      <td>0.127100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>239</td>\n",
       "      <td>0.129300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>240</td>\n",
       "      <td>0.131900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>241</td>\n",
       "      <td>0.128500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>242</td>\n",
       "      <td>0.130700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>243</td>\n",
       "      <td>0.134800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>244</td>\n",
       "      <td>0.124500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>245</td>\n",
       "      <td>0.133600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>246</td>\n",
       "      <td>0.133500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>247</td>\n",
       "      <td>0.129300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>248</td>\n",
       "      <td>0.134500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>249</td>\n",
       "      <td>0.136400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>250</td>\n",
       "      <td>0.127100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unsloth: Will smartly offload gradients to save VRAM!\n"
     ]
    }
   ],
   "source": [
    "trainer_stats = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('output/qwen2.5-1.5b-sft-simple-lora/tokenizer_config.json',\n",
       " 'output/qwen2.5-1.5b-sft-simple-lora/special_tokens_map.json',\n",
       " 'output/qwen2.5-1.5b-sft-simple-lora/vocab.json',\n",
       " 'output/qwen2.5-1.5b-sft-simple-lora/merges.txt',\n",
       " 'output/qwen2.5-1.5b-sft-simple-lora/added_tokens.json',\n",
       " 'output/qwen2.5-1.5b-sft-simple-lora/tokenizer.json')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.save_pretrained(\"output/qwen2.5-1.5b-sft-simple-lora\")  # Local saving lora weights\n",
    "tokenizer.save_pretrained(\"output/qwen2.5-1.5b-sft-simple-lora\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "vllm inference\n",
    "```bash\n",
    "vllm serve output/models/Qwen2.5-1.5B-Instruct --port 8100 --api-key NLUKKXIJDZ91rpg1z --enforce-eager  --max-model-len 4096 --enable-lora --max-lora-rank 64 --lora-modules qwen2.5-1.5b-sft-simple-lora=output/qwen2.5-1.5b-sft-simple-lora\n",
    "\n",
    "CURATOR_VIEWER=1 python eval.py --provider vllm --data_path data/test.jsonl --model_name qwen2.5-1.5b-sft-simple-lora --temperature 0.01 --max_tokens 1024\n",
    "\n",
    "https://curator.bespokelabs.ai/datasets/c39bd5aca6674beeba73d34aad53a328\n",
    "\n",
    "Accuracy: 4/100 (4.00%)\n",
    "```"
   ]
  },
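  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To spot-check the served adapter directly, a hedged sketch against vLLM's OpenAI-compatible endpoint (assumes the `vllm serve` command above is running on port 8100 with the same API key; the prompt is illustrative):\n",
    "\n",
    "```python\n",
    "from openai import OpenAI\n",
    "\n",
    "client = OpenAI(base_url=\"http://localhost:8100/v1\", api_key=\"NLUKKXIJDZ91rpg1z\")\n",
    "resp = client.chat.completions.create(\n",
    "    model=\"qwen2.5-1.5b-sft-simple-lora\",  # the name registered via --lora-modules\n",
    "    messages=[\n",
    "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "        {\"role\": \"user\", \"content\": \"Using the numbers 36, 70, 644, 97, create an equation that equals 653. ...\"},\n",
    "    ],\n",
    "    temperature=0.01,\n",
    "    max_tokens=1024,\n",
    ")\n",
    "print(resp.choices[0].message.content)\n",
    "```"
   ]
  },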
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "分析：\n",
    "\n",
    "1. 格式上 SFT 后完全遵守，没有问题\n",
    "2. 不再出现数字使用错误（包括没有使用或者使用超过一次）\n",
    "3. 所有的错误都是计算错误，最终不相等\n",
    "4. 训练收敛的很快，收敛的只是格式而不是正确性\n",
    "\n",
    "\n",
    "本质还是合成的简单推理过程过于简洁，模型没有学会推理。"
   ]
  },
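  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal verification sketch for point 3 above: parse the `<answer>` tag, confirm each number is used exactly once, and evaluate the equation. The field names (`nums`, `target`) are assumptions about the evaluation data, not something this notebook defines:\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "def check_answer(completion: str, nums: list[int], target: int) -> bool:\n",
    "    m = re.search(r\"<answer>(.*?)</answer>\", completion, re.DOTALL)\n",
    "    if m is None:\n",
    "        return False  # format error\n",
    "    equation = m.group(1).strip()\n",
    "    used = [int(n) for n in re.findall(r\"\\d+\", equation)]\n",
    "    if sorted(used) != sorted(nums):\n",
    "        return False  # number-usage error (missing, extra, or repeated numbers)\n",
    "    if not re.fullmatch(r\"[\\d+\\-*/(). ]+\", equation):\n",
    "        return False  # unexpected characters, refuse to eval\n",
    "    try:\n",
    "        return abs(eval(equation) - target) < 1e-6  # calculation error if this is False\n",
    "    except Exception:\n",
    "        return False\n",
    "```"
   ]
  },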
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一个有意思的现象\n",
    "\n",
    "1. 刚开始数据集合成的时候有个Bug，导致数字的顺序和解答的顺序完全一致。这时候 Simple SFT 的正确率是 19%。\n",
    "2. 也就是说 Simple SFT 可以学习到比较简单的推理（顺序的），但是稍微复杂的推理基本学习不到。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
